#include <torch/csrc/jit/serialization/pickle.h>

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_read.h>

namespace torch::jit {

namespace {

// Resolves a qualified type name encountered while unpickling.
//
// Names under the `__torch__` namespace are looked up in the global custom
// class registry (torch::getCustomClass); anything else is handed to the
// default ScriptTypeParser. Throws (via TORCH_CHECK) if a `__torch__` name is
// not registered.
c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) {
  if (!c10::QualifiedName("__torch__").isPrefixOf(qn)) {
    // This is a regular type, fall back to the default type parser
    torch::jit::ScriptTypeParser parser;
    return c10::StrongTypePtr(nullptr, parser.parseType(qn.qualifiedName()));
  }

  at::TypePtr type(torch::getCustomClass(qn.qualifiedName()));
  // TORCH_CHECK concatenates its message arguments (it does not perform
  // "{}"-style substitution), so the name must be passed as its own argument.
  TORCH_CHECK(
      type != nullptr,
      "Couldn't resolve type '",
      qn.qualifiedName(),
      "', did you forget to add its build dependency?");

  // Passing nullptr is a little bit sus, but should be fine:
  // 1. The lifetime of the class type is not tied to a specific
  //    CompilationUnit but rather the global custom class registry.
  // 2. We will not access the `cu_` field and immediately discard this
  //    StrongTypePtr post-deserialization.
  return c10::StrongTypePtr(nullptr, std::move(type));
}

} // namespace

void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  Pickler pickler(std::move(writer), tensor_table, nullptr, nullptr);
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();
}

// Convenience overload: pickles `ivalue` into an in-memory byte buffer by
// delegating to the writer-callback overload above.
std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  std::vector<char> out;

  // Every chunk produced by the pickler is appended to `out`.
  auto append = [&out](const char* chunk, size_t chunk_len) {
    out.insert(out.end(), chunk, chunk + chunk_len);
  };
  pickle(append, ivalue, tensor_table);

  return out;
}

// This has to live here instead of the C++ API to mirror torch.save since the
// mobile build excludes the C++ API
// Serializes `ivalue` into a complete PyTorch zip archive held in memory.
//
// The archive contains the pickle stream as `data.pkl` plus one `data/<i>`
// record per tensor collected by the pickler. Not available on mobile builds
// (throws via TORCH_CHECK).
std::vector<char> pickle_save(const at::IValue& ivalue) {
#ifndef C10_MOBILE
  // Pickle the IValue into an array of bytes
  std::vector<char> pickle_data;
  Pickler pickler([&](const char* buf, size_t size) {
    pickle_data.insert(pickle_data.end(), buf, buf + size);
  });
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();

  std::vector<char> container_data;
  container_data.reserve(pickle_data.size());

  caffe2::serialize::PyTorchStreamWriter writer(
      [&](const void* void_bytes, size_t len) {
        const char* bytes = reinterpret_cast<const char*>(void_bytes);
        container_data.insert(container_data.end(), bytes, bytes + len);
        return len;
      });

  // Write the generated bytes and the associated tensors into a data.pkl file
  // and data/0, data/1, data/2... files for each of the tensors
  writeArchiveAndTensors(
      "data",
      pickle_data.data(),
      pickle_data.size(),
      pickler.tensorData(),
      writer);

  // Finalize the archive explicitly before returning. `writer` is destroyed
  // only after `container_data` has been copied/moved into the return value
  // (NRVO is not guaranteed for a named local), so any trailing bytes emitted
  // during its destruction would never reach the caller.
  writer.writeEndOfFile();
  return container_data;
#else
  TORCH_CHECK(
      false,
      "pickle_save not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}

#ifndef C10_MOBILE
// Copies up to `n` bytes starting at offset `pos` of the owned buffer into
// `buf` and returns the number of bytes actually copied.
//
// A read that starts at or past the end copies nothing and returns 0. A read
// that would run past the end is truncated to the available bytes. This avoids
// reading out of bounds on truncated/malformed archives. `what` is unused.
size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what)
    const {
  const uint64_t total = data_.size();
  if (pos >= total) {
    return 0;
  }
  const size_t count = static_cast<size_t>(std::min<uint64_t>(n, total - pos));
  const char* begin = data_.data() + pos;
  std::copy(begin, begin + count, static_cast<char*>(buf));
  return count;
}

// Copies up to `n` bytes starting at offset `pos` of the viewed buffer into
// `buf` and returns the number of bytes actually copied.
//
// A read that starts at or past the end copies nothing and returns 0. A read
// that would run past the end is truncated to the available bytes. This avoids
// reading out of bounds on truncated/malformed archives. `what` is unused.
size_t StringViewReader::read(
    uint64_t pos,
    void* buf,
    size_t n,
    const char* what) const {
  const uint64_t total = data_.size();
  if (pos >= total) {
    return 0;
  }
  const size_t count = static_cast<size_t>(std::min<uint64_t>(n, total - pos));
  const char* begin = data_.data() + pos;
  std::copy(begin, begin + count, static_cast<char*>(buf));
  return count;
}
#endif

// Loads an IValue from an in-memory PyTorch archive (the inverse of
// pickle_save).
//
// No custom type resolver, object loader, or target device is supplied. Not
// available on mobile builds (throws via TORCH_CHECK).
IValue pickle_load(const std::vector<char>& data) {
#ifndef C10_MOBILE
  // Wrap (a copy of) the raw bytes in an in-memory read adapter and open it
  // as a stream-reader archive.
  auto adapter = std::make_unique<VectorReader>(data);
  caffe2::serialize::PyTorchStreamReader archive(std::move(adapter));

  return readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/std::nullopt,
      /*obj_loader=*/std::nullopt,
      /*device=*/std::nullopt,
      archive);
#else
  TORCH_CHECK(
      false,
      "pickle_load not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}

// A specialized version of pickle_load that can load custom objects.
// Loads an IValue from an in-memory PyTorch archive viewed by `data`.
//
// Types are resolved with customClassResolver, so `__torch__` custom classes
// are resolved through the custom class registry. Objects are built with
// torch::jit::ObjLoaderFunc. The bytes behind `data` must stay alive for the
// duration of the call, because the reader only views them. Not available on
// mobile builds (throws via TORCH_CHECK).
c10::IValue pickle_load_obj(std::string_view data) {
#ifndef C10_MOBILE
  caffe2::serialize::PyTorchStreamReader reader(
      std::make_unique<torch::jit::StringViewReader>(data));
  return torch::jit::readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/customClassResolver,
      /*obj_loader=*/torch::jit::ObjLoaderFunc,
      /*device=*/std::nullopt,
      reader);
#else
  // Name this function (not pickle_load) so mobile users see which API failed.
  TORCH_CHECK(
      false,
      "pickle_load_obj not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}

// Core unpickling entry point: pulls bytes on demand from `reader` and parses
// a single IValue from them.
//
// `type_resolver`, `tensor_table`, `obj_loader` and `type_parser` are
// forwarded to the Unpickler unchanged.
IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&),
    ObjLoader obj_loader) {
  Unpickler up(
      std::move(reader),
      std::move(type_resolver),
      tensor_table,
      std::move(obj_loader),
      type_parser);
  IValue parsed = up.parse_ivalue();
  return parsed;
}

// Buffer overload without an object loader. Equivalent to the ObjLoader-taking
// buffer overload called with an empty loader.
IValue unpickle(
    const char* data,
    size_t size,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&)) {
  ObjLoader no_obj_loader = nullptr;
  return unpickle(
      data,
      size,
      std::move(no_obj_loader),
      std::move(type_resolver),
      tensor_table,
      type_parser);
}

// Buffer overload: unpickles from the flat byte range [data, data + size).
//
// The range is served to the reader-function overload in chunks of at most
// the requested length. Once every byte has been handed out, further requests
// return 0.
IValue unpickle(
    const char* data,
    size_t size,
    ObjLoader obj_loader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&)) {
  size_t offset = 0;
  auto next_chunk = [&](char* buffer, size_t len) -> size_t {
    if (offset >= size) {
      return 0; // input exhausted
    }
    const size_t count = std::min(size - offset, len);
    std::memcpy(buffer, data + offset, count);
    offset += count;
    return count;
  };

  return unpickle(
      next_chunk,
      std::move(type_resolver),
      tensor_table,
      type_parser,
      std::move(obj_loader));
}

} // namespace torch::jit
